<!DOCTYPE html>
<html lang="en-US">
<head>
    <meta charset="utf-8"/>
    <meta content="IE=edge" http-equiv="X-UA-Compatible"/>
    <meta content="width=device-width, initial-scale=1" name="viewport"/>
    <link href="/assets/css/style.css" rel="stylesheet"/>
</head>
</html>
<body>
<div class="wrapper">
    <header>
        <div style="display: flex; align-items: center; height: 100%; margin-bottom: 5px;">
            <img src="/assets/img/backpack_logo_torch.svg"
                 style="width: 48px; margin-right:10px">
            <h1  style="margin-top:auto; margin-bottom:auto; display:block">
                <a href="https://f-dangel.github.io/backpack/">
                    BackPACK
                </a>
            </h1>
        </div>
        <p>
            Get more out of your backward pass
        </p>
        <p class="view">
        <ul>
            <li>
                <a href="examples.html">
                    BackPACK on a small example
                </a>
            </li>
        </ul>
        </p>
        <p class="view">
        <ul>
            <li>
                <a href="https://docs.backpack.pt">
                    Documentation
                </a>
            </li>
        </ul>
        </p>
        <p class="view">
        <ul>
            <li>
                <a href="https://github.com/f-dangel/backpack">
                    GitHub repo
                </a>
            </li>
        </ul>
        </p>
    </header>
    <section>
        <h1 id="backpack-on-a-small-example">BackPACK on a small example</h1>

<p>This small example shows how to use BackPACK to implement a simple second-order optimizer.
It follows <a href="https://github.com/pytorch/examples/tree/master/mnist">the traditional PyTorch MNIST example</a>.</p>

<h2 id="installation">Installation</h2>

<p>For this example to run, you will need <a href="https://pytorch.org/get-started/locally/">PyTorch and TorchVision (&gt;= 1.0)</a>.
To install BackPACK, either use <code class="language-plaintext highlighter-rouge">pip</code> or <a href="https://github.com/f-dangel/backpack">clone the repo</a>.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip install backpack-for-pytorch
</code></pre></div></div>
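
<p>As a quick sanity check after installing, the following minimal sketch reports whether the three packages used in this example can be found in the current environment. The helper name <code class="language-plaintext highlighter-rouge">is_installed</code> is hypothetical and only uses the Python standard library; the module name <code class="language-plaintext highlighter-rouge">backpack</code> is taken from the imports used further down.</p>

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be located in this environment."""
    return importlib.util.find_spec(module_name) is not None


# `backpack` is the module name used by the imports in this example.
for name in ("torch", "torchvision", "backpack"):
    status = "found" if is_installed(name) else "missing"
    print("%-12s %s" % (name, status))
```

<p>If any line reports <code class="language-plaintext highlighter-rouge">missing</code>, install that package before continuing.</p>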

<h2 id="an-example-diagonal-ggn-preconditioner">An example: Diagonal GGN Preconditioner</h2>

<p>You can find the code
<a href="https://docs.backpack.pt/en/master/use_cases/example_diag_ggn_optimizer.html">in the documentation</a>.
It runs SGD with a preconditioner based on the diagonal of the generalized Gauss-Newton (GGN) matrix.</p>

<h3 id="step-1-libraries-mnist-and-the-model">Step 1: Libraries, MNIST, and the model</h3>

<p>Let’s start with the imports and setting some hyperparameters. 
In addition to PyTorch and TorchVision, 
we’re going to load the main components we need from BackPACK:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torchvision</span>
<span class="c1"># The main BackPACK functionalities
</span><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">backpack</span><span class="p">,</span> <span class="n">extend</span>
<span class="c1"># The diagonal GGN extension
</span><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">DiagGGNMC</span>
<span class="c1"># This layer did not exist in PyTorch 1.0
</span><span class="kn">from</span> <span class="nn">backpack.core.layers</span> <span class="kn">import</span> <span class="n">Flatten</span>

<span class="c1"># Hyperparameters
</span><span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">STEP_SIZE</span> <span class="o">=</span> <span class="mf">0.01</span>
<span class="n">DAMPING</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">MAX_ITER</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
</code></pre></div></div>

<p>Now, let’s load MNIST:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mnist_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">torchvision</span><span class="p">.</span><span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span>
        <span class="s">'./data'</span><span class="p">,</span>
        <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
        <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
        <span class="n">transform</span><span class="o">=</span><span class="n">torchvision</span><span class="p">.</span><span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">([</span>
            <span class="n">torchvision</span><span class="p">.</span><span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">(),</span>
            <span class="n">torchvision</span><span class="p">.</span><span class="n">transforms</span><span class="p">.</span><span class="n">Normalize</span><span class="p">(</span>
                <span class="p">(</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,)</span>
            <span class="p">)</span>
        <span class="p">])),</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span>
    <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span>
<span class="p">)</span>

</code></pre></div></div>
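
<p>To make the <code class="language-plaintext highlighter-rouge">Normalize</code> step concrete, here is a dependency-free sketch of the per-pixel arithmetic it applies, using the mean and standard deviation passed above. The function name <code class="language-plaintext highlighter-rouge">normalize</code> is hypothetical, and the sketch assumes pixel values already scaled to the range 0 to 1, as <code class="language-plaintext highlighter-rouge">ToTensor</code> does for 8-bit images.</p>

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Per-pixel arithmetic of Normalize: (pixel - mean) / std."""
    return (pixel - mean) / std


# The two extremes of the input range after ToTensor.
print("black pixel: %.4f" % normalize(0.0))
print("white pixel: %.4f" % normalize(1.0))
```

<p>A pixel equal to the mean maps to zero, so the network sees inputs that are roughly centered.</p>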

<p>We’ll create a small CNN with max-pooling layers and ReLU activations, using a <a href="https://pytorch.org/docs/stable/nn.html#sequential"><code class="language-plaintext highlighter-rouge">Sequential</code></a> container as the main model.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
    <span class="n">Flatten</span><span class="p">(),</span>  <span class="c1"># PyTorch &lt; 1.2 doesn't have a Flatten layer
</span>    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="o">*</span><span class="mi">4</span><span class="o">*</span><span class="mi">50</span><span class="p">,</span> <span class="mi">500</span><span class="p">),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span>
<span class="p">)</span>

</code></pre></div></div>
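
<p>The input size <code class="language-plaintext highlighter-rouge">4*4*50</code> of the first linear layer follows from the image size and the layers before it. The sketch below retraces that arithmetic with the standard output-size formula for convolution and pooling layers without dilation; the helper name <code class="language-plaintext highlighter-rouge">conv_out</code> is hypothetical, and 28x28 is the size of MNIST images.</p>

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                      # MNIST images are 28x28
size = conv_out(size, 5, 1)    # Conv2d(1, 20, 5, 1) gives 24
size = conv_out(size, 2, 2)    # MaxPool2d(2, 2) gives 12
size = conv_out(size, 5, 1)    # Conv2d(20, 50, 5, 1) gives 8
size = conv_out(size, 2, 2)    # MaxPool2d(2, 2) gives 4
channels = 50                  # output channels of the second convolution

flat_features = size * size * channels
print(flat_features)           # 800, i.e. 4*4*50
```

<p>If you change a kernel size or the input resolution, the first argument of the first <code class="language-plaintext highlighter-rouge">Linear</code> layer has to change accordingly.</p>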

<p>We will also need a loss function and a way to measure accuracy:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">loss_function</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>

<span class="k">def</span> <span class="nf">get_accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
    <span class="s">"""Helper function to compute the accuracy"""</span>
    <span class="n">predictions</span> <span class="o">=</span> <span class="n">output</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">view_as</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">predictions</span><span class="p">.</span><span class="n">eq</span><span class="p">(</span><span class="n">targets</span><span class="p">).</span><span class="nb">float</span><span class="p">().</span><span class="n">mean</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>

</code></pre></div></div>
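
<p>For readers who want to see what the accuracy helper computes without tensors, here is a plain-Python sketch of the same idea: take the index of the largest score in each row and compare it with the target. The function name <code class="language-plaintext highlighter-rouge">accuracy_from_scores</code> and the toy scores are made up for illustration.</p>

```python
def accuracy_from_scores(scores, targets):
    """Fraction of rows whose highest score sits at the target index."""
    predictions = [row.index(max(row)) for row in scores]
    hits = sum(1 for p, t in zip(predictions, targets) if p == t)
    return hits / len(targets)


# Three samples, three classes; the last prediction (class 2) misses target 1.
scores = [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2], [0.0, 0.1, 0.9]]
targets = [1, 0, 1]
print(accuracy_from_scores(scores, targets))
```

<p>Two of the three predictions match, so the result is two thirds.</p>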

<h3 id="step-2-the-optimizer">Step 2: The optimizer</h3>

<p>The update rule we want to implement is preconditioned gradient descent,
using the diagonal of the generalized Gauss-Newton,</p>

<center>
<img src="assets/img/updaterule.png" width="60%" />
</center>

<p>where <code class="language-plaintext highlighter-rouge">𝛼</code> is the step size, <code class="language-plaintext highlighter-rouge">𝜆</code> is the damping parameter, <code class="language-plaintext highlighter-rouge">g</code> is the gradient and <code class="language-plaintext highlighter-rouge">G</code> is the diagonal of the generalized Gauss-Newton (GGN).
The difficult part is computing <code class="language-plaintext highlighter-rouge">G</code>, but BackPACK will do this for us.
Just like PyTorch’s autograd computes the gradient for each parameter <code class="language-plaintext highlighter-rouge">p</code> and stores it in <code class="language-plaintext highlighter-rouge">p.grad</code>, BackPACK with the <code class="language-plaintext highlighter-rouge">DiagGGNMC</code> extension computes (a Monte-Carlo estimate of) the diagonal of the GGN and stores it in <code class="language-plaintext highlighter-rouge">p.diag_ggn_mc</code>.
We can now simply focus on implementing the optimizer that uses this information:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DiagGGNOptimizer</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Optimizer</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">parameters</span><span class="p">,</span> <span class="n">step_size</span><span class="p">,</span> <span class="n">damping</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span>
            <span class="n">parameters</span><span class="p">,</span> 
            <span class="nb">dict</span><span class="p">(</span><span class="n">step_size</span><span class="o">=</span><span class="n">step_size</span><span class="p">,</span> <span class="n">damping</span><span class="o">=</span><span class="n">damping</span><span class="p">)</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">param_groups</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s">"params"</span><span class="p">]:</span>
                <span class="n">step_direction</span> <span class="o">=</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span> <span class="o">/</span> <span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">diag_ggn_mc</span> <span class="o">+</span> <span class="n">group</span><span class="p">[</span><span class="s">"damping"</span><span class="p">])</span>
                <span class="n">p</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s">"step_size"</span><span class="p">]</span> <span class="o">*</span> <span class="n">step_direction</span><span class="p">)</span>
</code></pre></div></div>
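
<p>To see the effect of the preconditioner in isolation, the following dependency-free sketch applies the same elementwise arithmetic as the optimizer code above to two toy parameters. The function name <code class="language-plaintext highlighter-rouge">preconditioned_step</code> and all numbers are made up for illustration; they are not values produced by BackPACK.</p>

```python
def preconditioned_step(params, grads, diag_ggn, step_size, damping):
    """Return params - step_size * grads / (diag_ggn + damping), elementwise."""
    return [
        p - step_size * g / (d + damping)
        for p, g, d in zip(params, grads, diag_ggn)
    ]


# Toy values: same gradient in both coordinates, different curvature.
params = [1.0, 1.0]
grads = [2.0, 2.0]
diag_ggn = [0.0, 9.0]  # flat direction vs. strongly curved direction

new_params = preconditioned_step(
    params, grads, diag_ggn, step_size=0.5, damping=1.0
)
print(new_params)
```

<p>Both coordinates have the same gradient, but the one with the larger GGN entry takes a ten times smaller step. The damping term keeps the division well defined where the GGN entry is zero.</p>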

<h3 id="step-3-put-on-your-backpack">Step 3: Put on your BackPACK</h3>

<p>The last thing to do before running the optimizer is (i) tell BackPACK about the model and loss function used and (ii) create the optimizer.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">extend</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="n">extend</span><span class="p">(</span><span class="n">loss_function</span><span class="p">)</span>

<span class="n">optimizer</span> <span class="o">=</span> <span class="n">DiagGGNOptimizer</span><span class="p">(</span>
    <span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> 
    <span class="n">step_size</span><span class="o">=</span><span class="n">STEP_SIZE</span><span class="p">,</span> 
    <span class="n">damping</span><span class="o">=</span><span class="n">DAMPING</span>
<span class="p">)</span>
</code></pre></div></div>

<p>We are now ready to run!</p>

<h3 id="the-main-loop">The main loop</h3>

<p>This is a traditional optimization loop: load each minibatch and compute the minibatch loss,
but now run the backward pass inside a <code class="language-plaintext highlighter-rouge">with backpack(DiagGGNMC()):</code> block.
During the backward pass, BackPACK fills the <code class="language-plaintext highlighter-rouge">diag_ggn_mc</code> fields of the parameters, and the optimizer can then take its step.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">mnist_loader</span><span class="p">):</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

    <span class="n">accuracy</span> <span class="o">=</span> <span class="n">get_accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>

    <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>  <span class="c1"># reset gradients accumulated in the previous iteration
</span>    <span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">DiagGGNMC</span><span class="p">()):</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_function</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>

    <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>

    <span class="k">print</span><span class="p">(</span>
        <span class="s">"Iteration %3.d/%d   "</span> <span class="o">%</span> <span class="p">(</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">MAX_ITER</span><span class="p">)</span> <span class="o">+</span>
        <span class="s">"Minibatch Loss %.3f  "</span> <span class="o">%</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">+</span>
        <span class="s">"Accuracy %.0f"</span> <span class="o">%</span> <span class="p">(</span><span class="n">accuracy</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span> <span class="o">+</span> <span class="s">"%"</span>
    <span class="p">)</span>

    <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">&gt;=</span> <span class="n">MAX_ITER</span><span class="p">:</span>
        <span class="k">break</span>
</code></pre></div></div>

<p>If everything went fine, the output should look like this:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Iteration   0/100   Minibatch Loss 2.307   Accuracy 12%
Iteration   1/100   Minibatch Loss 2.318   Accuracy 8%
Iteration   2/100   Minibatch Loss 2.329   Accuracy 8%
Iteration   3/100   Minibatch Loss 2.281   Accuracy 19%
Iteration   4/100   Minibatch Loss 2.265   Accuracy 19%
...
Iteration  96/100   Minibatch Loss 0.319   Accuracy 86%
Iteration  97/100   Minibatch Loss 0.435   Accuracy 89%
Iteration  98/100   Minibatch Loss 0.330   Accuracy 94%
Iteration  99/100   Minibatch Loss 1.227   Accuracy 89%
Iteration 100/100   Minibatch Loss 0.173   Accuracy 95%
</code></pre></div></div>

    </section>
    <footer>
        <p>
            <small>
                Hosted on GitHub Pages — Theme by
                <a href="https://github.com/orderedlist">
                    orderedlist
                </a>
            </small>
        </p>
    </footer>
</div>
<script src="/assets/js/scale.fix.js">
</script>
</body>
